P6287 [COCI2016-2017#1] Mag

COCI2016-2017#1题解

此题有一个神仙结论:最优链上最多包含一个‘2’,且‘2’不能在链的两边。 至于证明过于简单但很麻烦(其实是我懒),给大家推荐一个博客,点这

有了这个结论之后,我们就可以愉快的开始树形dp了。

我们用dp1[u]dp1[u]表示包含u节点的最长1链,dp2[u]dp2[u]表示包含u节点的含一个2,其余全是1的最长链。
那么,我们不难得到:

1.当u为1时,dp1[u]=max(dp1[v])+1dp1[u]=max(dp1[v])+1
                    dp2[u]=max(dp2[v])+1~~~~~~~~~~~~~~~~~~~~dp2[u]=max(dp2[v])+1

2.当u为2时,dp2[u]=max(dp1[v])+1dp2[u]=max(dp1[v])+1

我们再用f11f1_1表示最大的dp1[v]dp1[v]的下标,f12f1_2表示次大的dp1[v]dp1[v]的下标,f21f2_1表示最大的dp2[v]dp2[v]的下标

每次更新dp1[u],dp2[u]dp1[u],dp2[u]时,维护上述3个值。

最后我们更新答案,可能会出现以下情况:

1.当u的值为1时

1.f11=f21f1_1=f2_1,说明最长的含2链包含了最长的1链,更新答案为Ans=2dp1[f12]+dp2[f21]+1Ans=\frac{2}{dp1[f1_2]+dp2[f2_1]+1}

2.不存在该上关系但有含2,更新答案为Ans=2dp1[f11]+dp2[f21]+1Ans=\frac{2}{dp1[f1_1]+dp2[f2_1]+1}

3.答案的链不含2,更新答案为Ans=1dp1[f11]+dp1[f12]+1Ans=\frac{1}{dp1[f1_1]+dp1[f1_2]+1}

1.当u的值为2时

最长的含2链为Ans=2dp1[f11]+dp1[f12]+1Ans=\frac{2}{dp1[f1_1]+dp1[f1_2]+1}

#include <cstdio>
#include <vector>
#include <algorithm>
#include <iostream>
using namespace std;
#define LL long long
#define INF 1e15

const int MAXN = 1000000;
int n,a,b;
LL p = INF , q = 1 , Mag[ MAXN + 5 ];
vector< int > Graph[ MAXN + 5 ];

void update( LL a , LL b ) {
	if( a * q < b * p ) p = a , q = b;		//更新时为了避免除法,移项变为乘法
	return;
}
LL f1[ MAXN + 5 ] , f2[ MAXN + 5 ];
void dfs( int u , int fa ) {
	int f1_1 = 0 , f1_2 = 0 , f2_1 = 0;
	if( Mag[ u ] == 1 ) f1[ u ] = 1;
	if( Mag[ u ] == 2 ) f2[ u ] = 1;
	
	for( int i = 0 ; i < Graph[ u ].size( ) ; i ++ ) {
		int v = Graph[ u ][ i ];
		if( v == fa ) continue;
		dfs( v , u );
		
		if( Mag[ u ] == 1 ) {
			f1[ u ] = max( f1[ u ] , f1[ v ] + 1 );
			f2[ u ] = max( f2[ u ] , f2[ v ] + 1 );
		}
		else if( Mag[ u ] == 2 )
			f2[ u ] = max( f2[ u ] , f1[ v ] + 1 );
		
		if( f1[ f1_1 ] < f1[ v ] ) {
			f1_2 = f1_1;
			f1_1 = v;
		}
		else if( f1[ f1_2 ] < f1[ v ] )
			f1_2 = v;
		if( f2[ f2_1 ] < f2[ v ] )
			f2_1 = v;
	}
	if( Mag[ u ] == 1 ) {
		if( f1_1 == f2_1 )
			update( 2 , f1[ f1_2 ] + f2[ f2_1 ] + 1 );
		else
			update( 2 , f1[ f1_1 ] + f2[ f2_1 ] + 1 );
		update( 1 , f1[ f1_1 ] + f1[ f1_2 ] + 1 );
	}
	if( Mag[ u ] == 2 )
		update( 2 , f1[ f1_1 ] + f1[ f1_2 ] + 1 );
}
int main( ) {
	scanf("%d",&n);
	for( int i = 1 ; i <= n - 1 ; i ++ ) {
		scanf("%d %d",&a,&b);
		Graph[ a ].push_back( b );
		Graph[ b ].push_back( a );
	}
	for( int i = 1 ; i <= n ; i ++ ) {
		scanf("%lld",&Mag[ i ]);
		p = min( Mag[ i ] , p );
	}
	dfs( 1 , 0 );
	
	LL r = __gcd( p , q );
	printf("%lld/%lld\n", p / r , q / r );
	return 0;
}